# Manually dispatched workflow (no push/PR triggers are declared here).
name: Build AutoGPTQ Wheels with CUDA for Windows

on:
  workflow_dispatch:

jobs:
  # Runs once per (os, pyver, cuda) combination of the matrix below.
  build_wheels:
    # Only run when the repository is owned by AutoGPTQ (skips forks).
    if: ${{ github.repository_owner == 'AutoGPTQ' }}
    # The matrix key is `pyver`; referencing a non-existent `matrix.python`
    # renders as an empty string in the job name.
    name: Build wheels for ${{ matrix.os }} and Python ${{ matrix.pyver }} and CUDA ${{ matrix.cuda }}
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [windows-latest]
        # Versions are quoted so that "3.10" is not parsed as the float 3.1.
        pyver: ["3.8", "3.9", "3.10", "3.11"]
        cuda: ["11.8"]  # wheels for 12.1 are built in build_wheels_pypi.yml
    defaults:
      run:
        shell: pwsh
    env:
      # Exposed to every step's script as $env:CUDA_VERSION.
      CUDA_VERSION: ${{ matrix.cuda }}

    steps:
      # Check out the repository contents into the workspace.
      # v4/v5 of these actions run on the Node 20 runtime; the v3 majors
      # target the deprecated Node 16 runtime.
      - uses: actions/checkout@v4

      # Install the Python version selected by the matrix.
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.pyver }}

      - name: Setup Miniconda
        uses: conda-incubator/setup-miniconda@v2.2.0
        with:
          # Use a dedicated `build` environment rather than auto-activating base.
          activate-environment: build
          auto-activate-base: false
          python-version: ${{ matrix.pyver }}
          add-pip-as-python-dependency: true
          # Channel configuration.
          channels: conda-forge,defaults
          channel-priority: true
          # NOTE(review): mamba is requested here but `use-mamba` is false, so it
          # looks like it is installed without being used - confirm before dropping.
          mamba-version: "*"
          use-mamba: false

      - name: Install Dependencies
        run: |
          # A pwsh step only propagates the exit code of the LAST native command,
          # so each install is checked explicitly; otherwise a failed CUDA or
          # torch install would be masked by a later successful pip call.
          conda install --yes cuda-toolkit -c "nvidia/label/cuda-${env:CUDA_VERSION}.0"
          if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE }

          # Refer to https://pytorch.org/get-started/locally/
          # Derive the PyTorch wheel index tag from the matrix CUDA version
          # (e.g. "11.8" -> "cu118") instead of hard-coding it.
          $cudaTag = 'cu' + $env:CUDA_VERSION.Replace('.', '')
          python -m pip install torch --index-url "https://download.pytorch.org/whl/$cudaTag"
          if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE }

          # Build tooling and Python packages installed alongside torch.
          python -m pip install --upgrade build setuptools wheel ninja numpy gekko pandas

      - name: Check install
        # Smoke test: import torch and print its version before building.
        run: |-
          python -c "import torch; print('torch version:', torch.__version__)"

      - name: Build Wheel
        run: |
          # Point both CUDA location variables at the active conda environment.
          $env:CUDA_HOME = $env:CONDA_PREFIX
          $env:CUDA_PATH = $env:CONDA_PREFIX

          # CUDA 11.8 and newer extend the architecture list with 8.9 and 9.0.
          if ([decimal]$env:CUDA_VERSION -ge 11.8) {
            $env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX'
          } else {
            $env:TORCH_CUDA_ARCH_LIST = '6.0 6.1 7.0 7.5 8.0 8.6+PTX'
          }

          # Produces both an sdist and a wheel under ./dist.
          python setup.py sdist bdist_wheel

      # upload-artifact v3 has been retired on github.com. v4 does not allow
      # several matrix jobs to upload to the same artifact name, so the matrix
      # values are made part of the name (one artifact per Python/CUDA pair).
      - uses: actions/upload-artifact@v4
        with:
          name: windows-cuda-wheels-py${{ matrix.pyver }}-cu${{ matrix.cuda }}
          path: ./dist/*.whl
          # Fail the job instead of only warning when no wheel was produced.
          if-no-files-found: error
